/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_
#define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1

#include <string>
#include <boost/shared_ptr.hpp>
#include <boost/asio.hpp>
#include <boost/asio/ssl.hpp>
#include <thrift/transport/TTransport.h>
#include <thrift/transport/TSSLContext.h>
#include <thrift/transport/TVirtualTransport.h>

namespace apache {
namespace thrift {
namespace transport {

class TSSLContext;

/**
 * OpenSSL implementation for SSL socket interface.
 *
 * Implemented on top of boost::asio::ssl::stream<boost::asio::ip::tcp::socket>.
 * The constructors are protected; instances are meant to be created through
 * TSSLSocketFactory, which is declared a friend of this class.
 */
class TSSLSocket : public TVirtualTransport<TSSLSocket>
{
public:
  /**
   * Destructor.
   */
  ~TSSLSocket();

  /**
   * TTransport interface.
   */
  bool isOpen();
  bool peek();
  void open();
  void close();
  uint32_t read(uint8_t* buf, uint32_t len);
  uint32_t readAll(uint8_t* buf, uint32_t len);
  void write(const uint8_t* buf, uint32_t len);

  /**
   * Get the host that the socket is connected to.
   *
   * @return string host identifier
   */
  std::string getHost();

  /**
   * Get the port that the socket is connected to.
   *
   * @return int port number
   */
  int getPort();

  /**
   * Set the host that the socket will connect to.
   *
   * @param host host identifier
   */
  void setHost(std::string host);

  /**
   * Set the port that the socket will connect to.
   *
   * @param port port number
   */
  void setPort(int port);

  /**
   * Controls whether the linger option is set on the socket.
   *
   * @param on      Whether SO_LINGER is on
   * @param linger  If linger is active, the number of seconds to linger for
   */
  void setLinger(bool on, int linger);

  /**
   * Whether to enable/disable Nagle's algorithm.
   *
   * @param noDelay Whether or not to disable the algorithm.
   */
  void setNoDelay(bool noDelay);

  /**
   * Set the connect timeout.
   *
   * @param ms timeout in milliseconds
   */
  void setConnTimeout(int ms);

  /**
   * Set the receive timeout.
   *
   * @param ms timeout in milliseconds
   */
  void setRecvTimeout(int ms);

  /**
   * Set the send timeout.
   *
   * @param ms timeout in milliseconds
   */
  void setSendTimeout(int ms);

  /**
   * Set the max number of recv retries in case of a THRIFT_EAGAIN error.
   *
   * @param maxRecvRetries maximum number of retries
   */
  void setMaxRecvRetries(int maxRecvRetries);

  /**
   * Set SO_KEEPALIVE.
   *
   * @param keepAlive whether keep-alive is enabled
   */
  void setKeepAlive(bool keepAlive);

  /**
   * Get socket information formatted as a string <Host: x Port: x>.
   */
  std::string getSocketInfo();

  /**
   * Returns the DNS name of the host to which the socket is connected.
   */
  std::string getPeerHost();

  /**
   * Returns the address of the host to which the socket is connected.
   */
  std::string getPeerAddress();

  /**
   * Returns the port of the host to which the socket is connected.
   */
  int getPeerPort();

  /**
   * Get the origin the socket is connected to.
   *
   * @return string peer host identifier and port
   */
  const std::string getOrigin();

  /**
   * Set whether to use client or server side SSL handshake protocol.
   *
   * @param flag  Use server side handshake protocol if true.
   */
  void server(bool flag)
  {
    server_ = flag;
  }

  /**
   * Determine whether the SSL socket is server or client mode.
   *
   * @return true if in server mode, false if in client mode
   */
  bool server() const
  {
    return server_;
  }

  /**
   * Access the underlying TCP socket, i.e. the lowest layer of the SSL
   * stream. Callers may use it for socket-level operations that bypass
   * the SSL layer.
   */
  boost::asio::ssl::stream<boost::asio::ip::tcp::socket>::lowest_layer_type& getSocket()
  {
    return sslsocket_.lowest_layer();
  }
protected:
  /**
   * Constructor.
   *
   * explicit: prevents an accidental implicit conversion from a context
   * pointer to a socket.
   *
   * @param ctx  SSL context for this socket
   */
  explicit TSSLSocket(const boost::shared_ptr<TSSLContext>& ctx);

  /**
   * Constructor.
   *
   * @param ctx   SSL context for this socket
   * @param host  Remote host name
   * @param port  Remote port number
   */
  TSSLSocket(const boost::shared_ptr<TSSLContext>& ctx, const std::string& host, int port);

  /** connect, called by open */
  void openConnection();

  /**
   * Presumably performs the SSL handshake in the given role (client/server);
   * confirm in the implementation.
   */
  void handshake(boost::asio::ssl::stream_base::handshake_type type);

  /**
   * Certificate verification hook. The signature matches Boost.Asio's SSL
   * verify-callback shape (preverified flag + verify_context); presumably
   * installed as the stream's verify callback — confirm in the implementation.
   */
  bool verify_certificate(bool preverified, boost::asio::ssl::verify_context& ctx);

  /** True if the server-side handshake protocol is used. */
  bool server_;

  // NOTE: members are initialized in declaration order. sslctx_ is declared
  // before sslsocket_; keep that ordering if the constructor builds the
  // stream from the context (presumably it does — confirm in the .cpp).
  /** SSL context */
  boost::shared_ptr<TSSLContext> sslctx_;

  /** SSL stream layered over a TCP socket */
  boost::asio::ssl::stream<boost::asio::ip::tcp::socket> sslsocket_;

  /** Timer (presumably used to enforce the timeouts below; confirm) */
  boost::asio::deadline_timer deadline_;

  /** Host to connect to */
  std::string host_;

  /** Peer hostname */
  std::string peerHost_;

  /** Peer address */
  std::string peerAddress_;

  /** Peer port */
  int peerPort_;

  /** Port number to connect on */
  int port_;

  /** Connect timeout in ms */
  int connTimeout_;

  /** Send timeout in ms */
  int sendTimeout_;

  /** Recv timeout in ms */
  int recvTimeout_;

  /** Keep alive on */
  bool keepAlive_;

  /** Linger on */
  bool lingerOn_;

  /** Linger val */
  int lingerVal_;

  /** Nodelay */
  bool noDelay_;

  /** Max recv retries on THRIFT_EAGAIN */
  int maxRecvRetries_;

  friend class TSSLSocketFactory;
};

/**
 * SSL socket factory. SSL sockets should be created via SSL factory.
 *
 * TSSLSocket's constructors are protected; this class is a friend of
 * TSSLSocket and is the intended way to construct instances.
 */
class TSSLSocketFactory
{
public:
  /**
   * Constructor.
   *
   * @param ctx  SSL context held by this factory (presumably handed to every
   *             socket it creates; confirm in the implementation)
   */
  TSSLSocketFactory(const boost::shared_ptr<TSSLContext>& ctx);

  /**
   * Destructor. Virtual so the factory can be safely subclassed and deleted
   * through a base pointer.
   */
  virtual ~TSSLSocketFactory();

  /**
   * Create an instance of TSSLSocket with a fresh new socket.
   *
   * @return shared pointer to the new socket
   */
  virtual boost::shared_ptr<TSSLSocket> createSocket();

  /**
   * Create an instance of TSSLSocket.
   *
   * @param host  Remote host to be connected to
   * @param port  Remote port to be connected to
   * @return shared pointer to the new socket
   */
  virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port);

  /**
   * Set/Unset server mode.
   *
   * @param flag  Server mode if true
   */
  virtual void server(bool flag)
  {
    server_ = flag;
  }

  /**
   * Determine whether the socket is in server or client mode.
   *
   * @return true, if server mode, or, false, if client mode
   */
  virtual bool server() const
  {
    return server_;
  }

protected:
  /** SSL context passed in at construction */
  boost::shared_ptr<TSSLContext> ctx_;

private:
  /** Server mode flag (initialization happens in the out-of-line constructor, not visible here) */
  bool server_;
};

/**
 * SSL exception.
 *
 * Surfaces to callers as a TTransportException whose type is
 * INTERNAL_ERROR and whose message is the one given at construction.
 */
class TSSLException : public TTransportException
{
public:
  /**
   * @param message  Description of the SSL failure; what() returns it
   *                 unless it is empty.
   */
  TSSLException(const std::string& message)
    : TTransportException(TTransportException::INTERNAL_ERROR, message)
  {
  }

  /**
   * @return the stored message, or the literal "TSSLException" when no
   *         message was supplied.
   */
  virtual const char* what() const throw()
  {
    return message_.empty() ? "TSSLException" : message_.c_str();
  }
};

}
}
}

#endif
